package mymath;

public class MyMatrix {
    public double[][] mat;
    /**
     * 作为向量返回模长的平方
     */
    public double dotMulAsRow() {
        return dotMulAsRow(this);
    }

    /**
     * 作为向量计算点积
     */
    public double dotMulAsRow(MyMatrix other) {
        if (mat.length != 1 || other.mat.length != 1 || mat[0].length != other.mat[0].length) {
            throw new Error("Matrix dotMulForRows()");
        }
        return multiple(other.transpose()).mat[0][0];
    }

    /**
     * 用反射矩阵原理实现QR分解
     * 算法讲解：https://blog.csdn.net/ZHT2016iot/article/details/115448138
     */
    public MyMatrix[] houseHolder() {
        final int row = mat.length;
        final int col = mat[0].length;
        if (row != col) {
            throw new Error("Matrix houseHolder()");
        }
        MyMatrix mR = new MyMatrix(this);
        MyMatrix mH = MyMatrix.unitMatrix(row);
        for (int i = 0; i < row - 1; i++) {
            int len = row - i;
            MyMatrix mX = new MyMatrix(len, 1);
            for (int j = i; j < row; j++) {
                mX.mat[j - i][0] = mR.mat[j][i];
            }
            MyMatrix mY = new MyMatrix(len, 1);
            mY.mat[0][0] = Math.sqrt(mX.transpose().dotMulAsRow());
            MyMatrix mW = mX.minus(mY);
            mW = mW.multipleNumber(1 / Math.sqrt(mW.transpose().dotMulAsRow()));
            MyMatrix mGP = mW.multiple(mW.transpose()).multipleNumber(2);
            mGP = MyMatrix.unitMatrix(len).minus(mGP);
            MyMatrix mG = MyMatrix.unitMatrix(row);
            for (int j = i; j < row; j++) {
                for (int k = i; k < row; k++) {
                    mG.mat[j][k] = mGP.mat[j - i][k - i];
                }
            }
            mH = mG.multiple(mH);
            mR = mG.multiple(mR);
        }
        MyMatrix[] res = new MyMatrix[2];
        res[0] = mH.transpose();
        res[1] = mR;
        return res;
    }
    /**
     * 用QR迭代原理求特征值和特征向量
     * 算法讲解：https://blog.csdn.net/ZHT2016iot/article/details/115448138
     * @return 转置结果并排序，最终按行输出每行一个特征向量
     */
    public MyMatrix qrEgisAsRows() {
        final int row = mat.length;
        final int col = mat[0].length;
        if (row != col) {
            throw new Error("Matrix qrEgis()");
        }
        final int times = 100;
        MyMatrix mA = new MyMatrix(this);
        MyMatrix mAK = mA;
        MyMatrix mQ = MyMatrix.unitMatrix(row);
        for (int i = 0; i < times; i++) {
            MyMatrix[] qr = mA.houseHolder();
            mQ = mQ.multiple(qr[0]);
            mAK = mA;
            mA = qr[1].multiple(qr[0]);
        }
        double[] e = new double[row];
        MyMatrix[] vecs = mQ.transpose().sliceAllRows();
        for (int i = 0; i < row; i++) {
            e[i] = mAK.mat[i][i];
        }
        for (int i = 0; i < row; i++) {
            for (int j = i + 1; j < row; j++) {
                if (e[j] > e[i]) {
                    double tmp = e[i];
                    e[i] = e[j];
                    e[j] = tmp;
                    MyMatrix colTmp = vecs[i];
                    vecs[i] = vecs[j];
                    vecs[j] = colTmp;
                }
            }
        }
        return MyMatrix.fromRows(vecs);
    }

    /**
     * 拆分所有行到单行矩阵数组
     */
    public MyMatrix[] sliceAllRows() {
        final int row = mat.length;
        MyMatrix[] res = new MyMatrix[row];
        for (int i = 0; i < row; i++) {
            res[i] = sliceRow(i);
        }
        return res;
    }
    /**
     * 拆分矩阵的一行为单行矩阵
     */
    public MyMatrix sliceRow(int index) {
        final int row = 1;
        final int col = mat[0].length;
        if (index < 0 || index >= mat.length) {
            throw new Error("Matrix sliceRow()");
        }
        MyMatrix res = new MyMatrix(row, col);
        for (int j = 0; j < col; j++) {
            res.mat[0][j] = mat[index][j];
        }
        return res;
    }

    /**
     * 矩阵乘单个数
     */
    public MyMatrix multipleNumber(double k) {
        final int row = mat.length;
        final int col = mat[0].length;
        MyMatrix res = new MyMatrix(row, col);
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                res.mat[i][j] = mat[i][j] * k;
            }
        }
        return res;
    }

    /**
     * 矩阵乘法
     */
    public MyMatrix multiple(MyMatrix other) {
        if (mat[0].length != other.mat.length) {
            throw new Error("Matrix multiple()");
        }
        final int row = mat.length;
        final int vec = mat[0].length;
        final int col = other.mat[0].length;
        MyMatrix res = new MyMatrix(row, col);
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < vec; j++) {
                for (int k = 0; k < col; k++) {
                    res.mat[i][k] += mat[i][j] * other.mat[j][k];
                }
            }
        }
        return res;
    }

    /**
     * 矩阵转置
     */
    public MyMatrix transpose() {
        final int row = mat.length;
        final int col = mat[0].length;
        MyMatrix res = new MyMatrix(col, row);
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                res.mat[j][i] = mat[i][j];
            }
        }
        return res;
    }

    /**
     * 矩阵减法
     */
    public MyMatrix minus(MyMatrix other) {
        return plus(other.multipleNumber(-1));
    }

    /**
     * 矩阵加法
     */
    public MyMatrix plus(MyMatrix other) {
        if (mat.length != other.mat.length || mat[0].length != other.mat[0].length) {
            throw new Error("Matrix plus()");
        }
        final int row = mat.length;
        final int col = mat[0].length;
        MyMatrix res = new MyMatrix(row, col);
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                res.mat[i][j] = mat[i][j] + other.mat[i][j];
            }
        }
        return res;
    }

    /**
     * 从若干个单行矩阵联合为一个矩阵
     */
    public static MyMatrix fromRows(MyMatrix[] rows) {
        final int row = rows.length;
        final int col = rows[0].mat[0].length;
        MyMatrix res = new MyMatrix(row, col);
        for (int i = 0; i < row; i++) {
            if (rows[i].mat[0].length != col) {
                throw new Error("MyMatrix fromRows()");
            }
            for (int j = 0; j < col; j++) {
                res.mat[i][j] = rows[i].mat[0][j];
            }
        }
        return res;
    }

    /**
     * 从大小生成单位矩阵的工厂函数
     */
    public static MyMatrix unitMatrix(final int row) {
        final int col = row;
        MyMatrix res = new MyMatrix(row, col);
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                res.mat[i][j] = i == j ? 1 : 0;
            }
        }
        return res;
    }

    /**
     * 复制构造
     */
    public MyMatrix(MyMatrix other) {
        if (other == null) {
            throw new Error("new MyMatrix()");
        }
        final int row = other.mat.length;
        final int col = other.mat[0].length;
        mat = new double[row][col];
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                mat[i][j] = other.mat[i][j];
            }
        }
    }

    /**
     * 矩阵的序列化表达方便输出
     */
    @Override
    public String toString() {
        final int row = mat.length;
        final int col = mat[0].length;
        String s = "[";
        for (int i = 0; i < row; i++) {
            if (i > 0) {
                s += ",\n";
            }
            s += "[";
            for (int j = 0; j < col; j++) {
                if (j > 0) {
                    s += ", ";
                }
                s += String.format("%.6f", mat[i][j]);
            }
            s += "]";
        }
        s += "]";
        return s;
    }

    /**
     * 从大小构造全0矩阵
     */
    public MyMatrix(final int row, final int col) {
        if (row == 0 || col == 0) {
            throw new Error("new MyMatrix()");
        }
        mat = new double[row][col];
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                mat[i][j] = 0;
            }
        }
    }
}
